Binary tree tilt [Recursion]

Time: O(N); Space: O(N); easy

Given a binary tree, return the tilt of the whole tree.

The tilt of a tree node is defined as the absolute difference between the sum of all node values in its left subtree and the sum of all node values in its right subtree. A null node has tilt 0.

The tilt of the whole tree is defined as the sum of the tilts of all its nodes.
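To make the definition concrete, here is a direct, deliberately naive translation of it into Python. It is a sketch for illustration only: the helper names `subtree_sum`, `node_tilt`, and `tree_tilt_naive` are hypothetical, and it defines its own `TreeNode` with optional child arguments (unlike the notebook's class). Because `subtree_sum` is recomputed at every node, it takes O(N^2) time on a skewed tree, which is the cost the postorder solution avoids.

```python
from typing import Optional


class TreeNode:
    # Hypothetical node type for this sketch; children can be passed directly.
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def subtree_sum(node: Optional[TreeNode]) -> int:
    # Sum of every value in the subtree rooted at node; an empty subtree sums to 0.
    if node is None:
        return 0
    return node.val + subtree_sum(node.left) + subtree_sum(node.right)


def node_tilt(node: Optional[TreeNode]) -> int:
    # Tilt of one node: |sum(left subtree) - sum(right subtree)|; a null node has tilt 0.
    if node is None:
        return 0
    return abs(subtree_sum(node.left) - subtree_sum(node.right))


def tree_tilt_naive(node: Optional[TreeNode]) -> int:
    # Tilt of the whole tree: the sum of the tilts of all nodes.
    if node is None:
        return 0
    return node_tilt(node) + tree_tilt_naive(node.left) + tree_tilt_naive(node.right)
```

For Example 1, `tree_tilt_naive(TreeNode(1, TreeNode(2), TreeNode(3)))` returns 1.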

Example 1:

  1
 / \
2   3

Input: root = {TreeNode} [1,2,3]

Output: 1

Explanation:

  • Tilt of node 2: 0

  • Tilt of node 3: 0

  • Tilt of node 1: |2-3| = 1

  • Tilt of binary tree: 0 + 0 + 1 = 1

Example 2:

    1
   /
  1
 / \
2   3

Input: root = {TreeNode} [1,1,#,2,3]

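Output: 7

Explanation:

  • Tilt of node 2: 0

  • Tilt of node 3: 0

  • Tilt of the lower node 1 (parent of 2 and 3): |2-3| = 1

  • Tilt of the root node 1: |(1+2+3) - 0| = 6

  • Tilt of binary tree: 0 + 0 + 1 + 6 = 7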
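The `{TreeNode} [1,1,#,2,3]` notation appears to list node values in level order, with `#` standing for a missing child. Assuming that convention, a minimal builder could look like the sketch below; `build_tree` is a hypothetical helper, not part of the problem's API.

```python
from collections import deque
from typing import List, Optional, Union


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_tree(values: List[Union[int, str]]) -> Optional[TreeNode]:
    # values is a level-order listing where "#" marks a missing child (assumed convention).
    if not values or values[0] == "#":
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children to be assigned
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next entry (if any) is this node's left child.
        if values[i] != "#":
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The entry after that (if any) is this node's right child.
        if i < len(values) and values[i] != "#":
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root
```

Under that assumption, `build_tree([1, 1, "#", 2, 3])` reproduces the tree drawn for Example 2.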

Notes:

  • The sum of node values in any subtree won’t exceed the range of a 32-bit integer.

  • All the tilt values won’t exceed the range of a 32-bit integer.

Hints:

  1. Don’t overthink it; this is an easy problem. Work through a small tree as an example.

  2. Can a parent node use the values of its child nodes? How will you implement it?

  3. Maybe recursion and tree traversal can help you implement it.

  4. What about a postorder traversal that uses the values computed for the left and right children?

[1]:
class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

1. Using Recursion

Algorithm

From the problem statement, it is clear that we need to find the tilt value at every node of the given tree and add up all the tilt values to obtain the final result. To find the tilt value at any node, we take the absolute difference between the sum of all the nodes in its left subtree and the sum of all the nodes in its right subtree.

Thus, we use a recursive helper, postOrderTraverse, which, when called on any node, returns the sum of the values in the subtree rooted at that node (the node itself plus all of its descendants) together with the running tilt total. Because it visits both children before the node itself (postorder), the left and right subtree sums are already available when the node is processed, so its tilt is simply their absolute difference. Each node is visited once, which gives O(N) time.
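Instead of threading the running tilt through the return values, as Solution1 does, one common variant keeps the total in an enclosing variable and lets the helper return only the subtree sum. This is a sketch of that design choice, not the notebook's own solution; `find_tilt` and `subtree_sum` are hypothetical names.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def find_tilt(root: Optional[TreeNode]) -> int:
    total_tilt = 0  # accumulated tilt of every node visited so far

    def subtree_sum(node: Optional[TreeNode]) -> int:
        nonlocal total_tilt
        if node is None:
            return 0
        # Postorder: compute both children's sums before using them at this node.
        left = subtree_sum(node.left)
        right = subtree_sum(node.right)
        # Each node contributes |left sum - right sum| exactly once.
        total_tilt += abs(left - right)
        # Sum of the subtree rooted at node, including node itself.
        return left + right + node.val

    subtree_sum(root)
    return total_tilt
```

Both versions do the same postorder work; the `nonlocal` form simply trades a tuple return value for shared state in the enclosing function.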

See also: https://leetcode.com/problems/binary-tree-tilt/solution/

[2]:
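class Solution1(object):
    """
    Time: O(N), each node is visited once.
    Space: O(N) in the worst case, for the recursion stack on a skewed tree.
    """
    def findTilt(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """
        def postOrderTraverse(node, tilt):
            # Returns (sum of the subtree rooted at node, running tilt total).
            if not node:
                return 0, tilt

            # Postorder: process both children before the node itself.
            left, tilt = postOrderTraverse(node.left, tilt)
            right, tilt = postOrderTraverse(node.right, tilt)

            # This node's tilt is |left subtree sum - right subtree sum|.
            tilt += abs(left - right)

            return left + right + node.val, tilt

        return postOrderTraverse(root, 0)[1]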
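Solution1 recurses once per tree level, so a very deep (skewed) tree can exceed CPython's default recursion limit (typically 1000). As a hedged sketch, the same postorder computation can be driven by an explicit stack; `find_tilt_iterative` is a hypothetical name, and each node is pushed twice so that it is finished only after both of its children.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def find_tilt_iterative(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    # Subtree sum of every finished node; nodes hash by identity by default.
    sums: Dict[TreeNode, int] = {}
    total_tilt = 0
    # (node, children_done): the first visit schedules the children,
    # the second visit (children_done=True) finishes the node in postorder.
    stack: List[Tuple[TreeNode, bool]] = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            stack.append((node, True))
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, False))
        else:
            # A missing child (None) is never a key, so it defaults to sum 0.
            left = sums.get(node.left, 0)
            right = sums.get(node.right, 0)
            total_tilt += abs(left - right)
            sums[node] = left + right + node.val
    return total_tilt
```

The dictionary holds one subtree sum per node, so space remains O(N), but it no longer depends on the interpreter's call-stack depth.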
[3]:
s = Solution1()

root = TreeNode(1)
root.left, root.right = TreeNode(2), TreeNode(3)
assert s.findTilt(root) == 1

root = TreeNode(1)
root.left = TreeNode(1)
root.left.left, root.left.right = TreeNode(2), TreeNode(3)
assert s.findTilt(root) == 7